Hvad er den nemmeste måde at omdanne tensor af form (batch_størrelse, højde, bredde) fyldt med n-værdier til tensor af form (batch_størrelse, n, højde, bredde)? Jeg oprettede løsningen nedenfor, men ser ud til at der er en lettere og hurtigere måde at gøre dette på def batch_tensor_to_onehot (tnsr, klasser): tnsr = tnsr.unsqueeze (1) res = [] for cls inden for rækkevidde (klasser): res.append ((tnsr == cls). lang ()) return torch.cat (res, dim = 1)
2021-02-20 08:18:16
Du kan bruge torch.nn.functional.one_hot. Til din sag: a = fakkel.nn.functional.one_hot (tnsr, num_classes = klasser) out = a.permute (0, 3, 1, 2) | Du kan også bruge Tensor.scatter_, som undgår .permute, men uden tvivl vanskeligere at forstå end den ligefremme metode, der er foreslået af @Alpha. def batch_tensor_to_onehot (tnsr, klasser): resultat = fakkel.zeros (tnsr.shape [0], klasser, * tnsr.shape [1:], dtype = fakkel.long, enhed = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) returneresultat Benchmarking-resultater Jeg var nysgerrig og besluttede at sammenligne de tre tilgange. Jeg fandt ud af, at der ikke ser ud til at være en signifikant relativ forskel mellem de foreslåede metoder med hensyn til batchstørrelse, bredde eller højde. Primært var antallet af klasser den kendetegnende faktor. Naturligvis kan det som med enhver benchmark-kilometertal variere. Benchmarks blev indsamlet ved anvendelse af tilfældige indekser og ved anvendelse af batchstørrelse, højde, bredde = 100. Hvert eksperiment blev gentaget 20 gange med gennemsnittet rapporteret. Num_classes = 100 eksperimentet køres en gang før profilering til opvarmning. CPU-resultaterne viser, at den oprindelige metode sandsynligvis var bedst for num_klasser mindre end ca. 30, mens for GPU synes scatter_-tilgangen at være hurtigst. Test udført på Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K Koden, der bruges til benchmarking, er angivet nedenfor: importer fakkel fra tqdm importer tqdm importtid importer matplotlib.pyplot som plt def batch_tensor_to_onehot_slavka (tnsr, klasser): tnsr = tnsr.unsqueeze (1) res = [] for cls inden for rækkevidde (klasser): res.append ((tnsr == cls). lang ()) return torch.cat (res, dim = 1) def batch_tensor_to_onehot_alpha (tnsr, klasser): resultat = fakkel.nn.functional.one_hot (tnsr, num_classes = klasser) returneresultat. slå (0, 3, 1, 2) def batch_tensor_to_onehot_jodag (tnsr, klasser): resultat = fakkel.zeros (tnsr.shape [0], klasser, * tnsr.shape [1:], dtype = fakkel.long, enhed = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) returneresultat def hoved (): num_classes = [2, 10, 25, 50, 100] højde = 100 bredde = 100 bs = [100] * 20 for d i ['cpu', 'cuda']: times_slavka = [] times_alpha = [] times_jodag = [] opvarmning = Sandt for c i tqdm ([num_classes [-1]] + num_classes, ncols = 0): tslavka = 0 talpha = 0 tjodag = 0 for b i bs: tnsr = fakkel. rand (c, (b, højde, bredde)). til (enhed = d) t0 = tid.tid () y = batch_tensor_to_onehot_slavka (tnsr, c) torch.cuda.synchronize () tslavka + = time.time () - t0 hvis ikke opvarmning: times_slavka.append (tslavka / len (bs)) for b i bs: tnsr = fakkel. rand (c, (b, højde, bredde)). til (enhed = d) t0 = tid.tid () y = batch_tensor_to_onehot_alpha (tnsr, c) torch.cuda.synchronize () talpha + = tid.tid () - t0 hvis ikke opvarmning: times_alpha.append (talpha / len (bs)) for b i bs: tnsr = fakkel. rand (c, (b, højde, bredde)). til (enhed = d) t0 = tid.tid () y = batch_tensor_to_onehot_jodag (tnsr, c) torch.cuda.synchronize () tjodag + = time.time () - t0 hvis ikke opvarmning: times_jodag.append (tjodag / len (bs)) opvarmning = falsk fig = plt. figur () økse = fig. underplotter () ax.plot (num_classes, times_slavka, label = 'Slavka-cat') ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot') ax.plot (num_classes, times_jodag, label = 'jodag-scatter_') ax.set_xlabel ('num_classes') ax.set_ylabel ('tid (er)') ax.set_title (f '{d} benchmark') ax.legend () plt.savefig (f '{d} .png') plt.show () hvis __name__ == "__main__": hoved () | Dit svar StackExchange.ifUsing ("editor", funktion () { StackExchange.using ("externalEditor", funktion () { StackExchange.using ("uddrag", funktion () { StackExchange.snippets.init (); }); }); }, "kodestykke"); StackExchange.ready (funktion () { var channelOptions = { tags: "" .split (""), id: "1" }; initTagRenderer ("". split (""), "" .split (""), channelOptions); StackExchange.using ("externalEditor", funktion () { // Skal redigere editoren efter uddrag, hvis uddrag er aktiveret hvis (StackExchange.settings.snippets.snippetsEnabled) { StackExchange.using ("uddrag", funktion () { createEditor (); }); } andet { createEditor (); } }); funktion createEditor () { StackExchange.prepareEditor ({ useStacksEditor: falsk, heartbeatType: 'svar', autoActivateHeartbeat: false, convertImagesToLinks: sand, noModals: sandt, showLowRepImageUploadWarning: true, reputToPostImages: 10, bindNavPrevention: true, postfix: "", imageUploader: { brandingHtml: "Drevet af \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46.2665 7.94324 47.1084 7.58816C47.4091 7.46349 47.7169 7.36433 48.0099 7.26993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.1414 4.61182C47.4335 4.61182 46.7256 4.916 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C41.5985 5.261 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 34.0034.4702366 fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9633 4.8821 30.5679 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.913 25.3756.913 C28. 1256 12.8854 28,1301 12,9342 28,1301 12.983C28.1301 14,4373 27,2502 15,2321 25,777 15.2321C24.8349 15,2321 24,1352 14,9821 23,5661 14.7787C23.176 14,6393 22,8472 14,5218 22,5437 14.5218C21.7977 14,5218 21,2429 15,0123 21,2429 15.6887C21.2429 16,7375 22,9072 17,6335 25,6622 17.6335ZM24.1317 9,27932 C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.27932Z \ "/ \ u003e \ u003e \ u003e \ u003e \ u003e 8045 13.2535 17.2637 13.8962 18.2965 13.8962C19.3298 13.8962 19.8079 13.2535 19.8079 11.9512V8.12928C19.8079 5.82936 18.4879 4.62866 16.4027 4.62866C15.1594 4.62866 14.279 4.98375 13.3609 5.880134125.56 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512.9.10.993 13.2535 11.3711 13.8962 12.4044 13.8962C13.4315.996 13.2535 13.3511 11.8711 13.8962 12.4044 13.8962C13.4375.996 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821.213351 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.822846 3.57676 1.87209 3.5767.9 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e", contentPolicyHtml: "Brugerbidrag licenseret under \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (content policy) \ u003c / a \ u003e", allowUrls: sandt }, onDemand: sandt, discardSelector: ".discard-answer" , straksShowMarkdownHelp: true, enableTables: true, enableSnippets: true }); } }); Tak for dit bidrag til Stack Overflow! Sørg for at besvare spørgsmålet. Giv detaljer og del din forskning! Men undgå ... At bede om hjælp, afklaring eller svar på andre svar. At afgive udsagn baseret på mening; sikkerhedskopier dem med referencer eller personlig erfaring. For at lære mere, se vores tip til at skrive gode svar. Kladde gemt Udkast kasseret Tilmeld dig eller log ind StackExchange.ready (funktion () { StackExchange.helpers.onClickDraftSave ('# login-link'); }); Tilmeld dig ved hjælp af Google Tilmeld dig via Facebook Tilmeld dig ved hjælp af e-mail og adgangskode Indsend Send som gæst Navn E-mail Påkrævet, men aldrig vist StackExchange.ready ( funktion () { StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' ); } ); Send som gæst Navn E-mail Påkrævet, men aldrig vist Send dit svar Kassér Ved at klikke på "Send dit svar" accepterer du vores servicevilkår, fortrolighedspolitik og cookiepolitik Er det ikke det svar, du leder efter? Gennemse andre spørgsmål med tagget python pytorch tensor en-hot-kodning, eller still dit eget spørgsmål.